4th Exercise - Embeddings, Vector Databases & Search
We will use FAISS and CLIP to create and search pre-trained embedding vectors.
First, we need to install required libraries:
- CLIP: For generating embeddings
- FAISS: For indexing
import sys
# Uncomment as needed
# !{sys.executable} -m pip install --no-input openai-clip
# Either use CPU-only version
# !conda install --yes --prefix {sys.prefix} -c pytorch faiss-cpu=1.9.0
# Alternatively use the GPU(+CPU) version
# !conda install --yes --prefix {sys.prefix} -c pytorch -c nvidia faiss-gpu=1.9.0
# !conda install --yes --prefix {sys.prefix}
Load required librariesΒΆ
import os
# Work around duplicate OpenMP runtimes (loaded by numpy/faiss) crashing the kernel
os.environ['KMP_DUPLICATE_LIB_OK']='True'
import torch
import faiss
import clip
import json
import requests
from PIL import Image
from io import BytesIO
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import matplotlib.image as mpimg
from matplotlib.patches import Rectangle
from tqdm import tqdm
Some helper methods for testing resultsΒΆ
# Base URL of the Stanford Dogs image host; used to fetch result thumbnails for plotting
images_URL = "http://vision.stanford.edu/aditya86/ImageNetDogs/images/"
def load_embeddings(embeddings_path):
    """Load the image metadata frame and its embedding matrix.

    Parameters
    ----------
    embeddings_path : str
        Path to a gzipped, ';'-separated CSV whose "embedding" column
        holds JSON-encoded lists of floats.

    Returns
    -------
    df : pd.DataFrame
        The raw metadata, one row per image.
    embeddings : np.ndarray of shape (n_rows, dim), float32
        The decoded embedding vectors (float32, as FAISS requires).
    """
    df = pd.read_csv(embeddings_path, sep=';', compression='gzip', index_col=0)
    # Decode the JSON string column into one contiguous float32 matrix.
    start_time = time.time()
    data = df["embedding"].apply(json.loads)
    # Building the array in one call replaces the former pre-allocate +
    # row-by-row copy loop and also avoids indexing data[0] explicitly.
    embeddings = np.array(data.tolist(), dtype=np.float32)
    end_time = time.time()
    print(f"Time to load {np.round(end_time-start_time,3)}s")
    return df, embeddings
# Calculate Precision@k
def precision_at_k(indices, k):
    """Mean Precision@k over all query rows of ``indices``.

    The class of each query's top-1 neighbour (column 0) serves as the
    ground-truth label; a retrieved item counts as relevant when it
    shares that class. Reads the module-level ``df`` metadata frame.
    """
    true_classes = df.iloc[indices[:, 0]]["class"].values
    retrieved = df.iloc[indices.flatten()]["class"].values
    retrieved = retrieved.reshape(indices.shape[0], -1)
    # Boolean hit matrix over the first k columns, averaged per query.
    hits = retrieved[:, :k] == true_classes[:, None]
    per_query_precision = hits.sum(axis=1) / k
    return np.mean(per_query_precision)
# Plot search results
def plot_results(query_file_names, indices, k):
    """Visualise search results: one figure per query row in ``indices``.

    Parameters
    ----------
    query_file_names : list
        Per-query labels; entries that are image paths on disk are shown
        as thumbnails, anything else (e.g. a sentence) is rendered as
        text in the query panel.
    indices : array-like of shape (n_queries, >= k) or None
        Row ids into the global ``df`` for each query's neighbours.
    k : int
        Number of neighbour panels to draw per query.

    Side effects: downloads each result image from ``images_URL`` via
    HTTP and shows matplotlib figures; reads the global ``df``.
    """
    if indices is not None:
        for pos, ind in enumerate(indices):
            # One row of axes: the query panel plus k neighbour panels.
            fig, ax = plt.subplots(1, k+1, figsize=(3*(k+1),4))
            if query_file_names:
                path_query = f'{query_file_names[pos]}'
                if os.path.isfile(path_query):
                    # The query is an image file on disk: show it directly.
                    ax[0].imshow(mpimg.imread(path_query))
                    ax[0].axis('off')
                    ax[0].set_title("Query")
                elif len(query_file_names) > pos:
                    # Not a file (e.g. a text query): print the label instead.
                    # ax[0].set_title()
                    ax[0].text(0.5, 0.5,
                               f"{query_file_names[pos]}", fontsize=20,
                               ha='center', va='center')
                    ax[0].axis('off')
                else:
                    ax[0].remove()
            for i in range(0, k):
                d_train = df.iloc[ind[i]]
                # Fetch the original image from the Stanford Dogs server.
                url = f'{images_URL}/{d_train["dir"]}/{d_train["filename"]}.jpg'
                response = requests.get(url, stream=True)
                if response.status_code == 200:
                    img = Image.open(BytesIO(response.content))
                    ax[1+i].imshow(img)
                    ax[1+i].axis('off')
                    ax[1+i].set_title(f"{i+1}-NN: " + d_train["class"])
                    # Overlay the dataset's bounding box and class label.
                    xy = (d_train["xmin"], d_train["ymin"])
                    width = d_train["xmax"]-d_train["xmin"]
                    height = d_train["ymax"]-d_train["ymin"]
                    rect = Rectangle(xy, width, height, edgecolor='white', fill=None)
                    ax[1+i].add_patch(rect)
                    ax[1+i].text(xy[0], xy[1], d_train["class"], fontsize=12, color='white',
                                 verticalalignment='bottom', horizontalalignment='left')
                else:
                    # Download failed: drop the panel rather than show it empty.
                    ax[i+1].remove()
            plt.tight_layout()
            plt.show()
def plot_precision_at_k(indices, k):
    """Print and plot Precision@i for every cutoff i = 1..k.

    Parameters
    ----------
    indices : array-like of shape (n_queries, >= k)
        Retrieved neighbour ids, as returned by a FAISS search.
    k : int
        Largest cutoff to evaluate.
    """
    precisions = [precision_at_k(indices, i) for i in range(1, k+1)]
    # Fix: enumerate from 1 so the printed label matches the actual cutoff
    # (previously the first line read "precision@0").
    for i, precision in enumerate(precisions, start=1):
        print(f"precision@{i}: {precision}")
    fig, ax = plt.subplots(figsize=(8,4))
    ax.set_title("Precision@k")
    ax.plot(np.arange(1, k+1, dtype=np.int32), precisions)
    ax.set_xlabel("k")
    ax.set_ylabel("Precision@k")
    plt.show()
Load Embeddings and Meta-DataΒΆ
You must download the required files first, and adapt the paths:
- queries/: Queries needed for Tasks 1-4. Load from Moodle.
- data/vectors.csv.gz: Embeddings needed for Tasks 1-5. Load from Moodle.
### modify this path, if needed!
queries_start_path = "queries"           # folder holding the query images (Tasks 1-4)
embeddings_path = 'data/vectors.csv.gz'  # gzipped CSV with precomputed CLIP embeddings
df, embeddings = load_embeddings(embeddings_path)
embedding_dim = embeddings.shape[1]      # dimensionality of the embedding vectors
df.head()
Time to load 3.362s
| filename | dir | class | pose | xmin | ymin | xmax | ymax | embedding | |
|---|---|---|---|---|---|---|---|---|---|
| 0 | n02097658_26 | n02097658-silky_terrier | silky_terrier | Unspecified | 41 | 30 | 296 | 398 | [0.21719056367874146, 0.29560884833335876, 0.0... |
| 1 | n02097658_4869 | n02097658-silky_terrier | silky_terrier | Unspecified | 43 | 270 | 321 | 498 | [-0.3137703835964203, -0.5720250606536865, 0.0... |
| 2 | n02097658_595 | n02097658-silky_terrier | silky_terrier | Unspecified | 82 | 7 | 378 | 355 | [0.10297045111656189, -0.08061640709638596, -0... |
| 3 | n02097658_9222 | n02097658-silky_terrier | silky_terrier | Unspecified | 0 | 12 | 331 | 498 | [0.20631632208824158, 0.0758294016122818, 0.02... |
| 4 | n02097658_422 | n02097658-silky_terrier | silky_terrier | Unspecified | 146 | 10 | 356 | 332 | [0.02929862029850483, -0.12727777659893036, 0.... |
Task 1: Exact Search using Flat IndexΒΆ
FAISS implements several metrics and indices for similarity search.
Your task is to index all embeddings using IndexFlatL2 index first.
faiss.IndexFlatL2(d)
...
You can read more about their implementation and how to use IndexFlatL2 on their GitHub page: https://github.com/facebookresearch/faiss/wiki
def create_IndexFlatL2(embeddings):
    """Build an (empty) exact-search flat L2 index.

    The index is sized to the dimensionality of ``embeddings``; vectors
    are added later by the caller.
    """
    dim = embeddings.shape[1]
    return faiss.IndexFlatL2(dim)
# Build (but do not yet fill) the flat index, timing the construction
start_time = time.time()
index_flat = create_IndexFlatL2(embeddings)
end_time = time.time()
print(f"Time to index {np.round(end_time-start_time,3)}s")
Time to index 0.001s
Task 2: Generate embeddings for queries using OpenAI's CLIPΒΆ
You are given a set of images in the queries subfolder. Your task is to:
- Generate CLIP embeddings for each image
- Search for the image embeddings using the FAISS index
CLIP [1] is an embedding model by OpenAI, which is used to extract a high-dimensional vector representation of an image, capturing its semantic and perceptual features.
Initialize OpenAi's CLIP Encoder for feature extractionΒΆ
# Prefer the GPU when available; CLIP inference is much faster on CUDA
device = "cuda" if torch.cuda.is_available() else "cpu"
print("running on: " + device)
# Load the ViT-B/32 CLIP checkpoint together with its matching image preprocessor
model, preprocess = clip.load("ViT-B/32", device=device)
running on: cuda
Now Generate Embeddings of QueriesΒΆ
Use CLIP to generate for all images in the queries folder the corresponding embedding.
model.encode_image(image)
See: https://github.com/openai/CLIP
Upon generation, make sure to convert all data to float32, which is required by FAISS.
def create_query_embeddings(queries_start_path):
    """Embed every ``*.jpg`` below ``queries_start_path`` with CLIP.

    Parameters
    ----------
    queries_start_path : str
        Root folder that is walked recursively for query images.

    Returns
    -------
    query_embeddings : np.ndarray of shape (n_queries, dim), float32
        One CLIP image embedding per query file (float32 for FAISS).
    query_file_names : list[str]
        Matching file paths, in the same order as the rows above.

    Uses the module-level ``preprocess``/``model``/``device`` CLIP
    objects, and the global ``embeddings`` only to size the empty result.
    """
    query_file_names = []
    for root, dirs, files in os.walk(queries_start_path):
        for file_name in files:
            # endswith instead of substring: don't match names that merely
            # contain ".jpg" somewhere in the middle.
            if file_name.endswith(".jpg"):
                query_file_names.append(f"{root}/{file_name}")
    if not query_file_names:
        # Preserve the original empty-result shape: (0, dim) float32.
        return np.zeros((0, embeddings.shape[1]), dtype=np.float32), query_file_names
    features = []
    for query_file_name in query_file_names:
        image = preprocess(Image.open(query_file_name)).unsqueeze(0).to(device)
        with torch.no_grad():  # inference only, no gradients needed
            image_features = model.encode_image(image)
        features.append(image_features.cpu().numpy())
    # Stack once instead of the former repeated O(n^2) np.vstack;
    # cast explicitly to float32 as FAISS requires.
    query_embeddings = np.concatenate(features, axis=0).astype(np.float32)
    return query_embeddings, query_file_names
# Embed all images found in the queries folder
query_embeddings, query_file_names = create_query_embeddings(queries_start_path)
Querying: Search for top-3 relevant imagesΒΆ
We need to define a search function below to first vectorize our query image, and then search for the vectors with the closest distance.
index_flat.search(query_embeddings, k)
k = 3
def indexed_search(index, query_embeddings, k):
    """Return the ids of the k nearest neighbours for every query row.

    On first use the index is lazily trained on and filled with the
    module-level ``embeddings``; subsequent calls only search.
    """
    if not index.is_trained:
        print("Training index...")
        index.train(embeddings)
    if index.ntotal == 0:
        print("Adding embeddings to index...")
        index.add(embeddings)
    # Distances are not needed by the callers, only the neighbour ids.
    _, neighbour_ids = index.search(query_embeddings, k)
    return neighbour_ids
# NOTE(review): on first use this also times index.add(), not only the
# search itself - the printed label says "index" for that reason.
start_time = time.time()
flat_indices = indexed_search(index_flat, query_embeddings, k)
end_time = time.time()
print(f"Time to index {np.round(end_time-start_time,3)}s")
Adding embeddings to index... Time to index 0.021s
Plot resultsΒΆ
We use the helper function, to illustrate the results
# Show each query image next to its top-k retrieved dogs
plot_results(query_file_names, flat_indices, k)
Measure precision@kΒΆ
We measure the precision@k for different values of k.
# A simple test to test correctness
if flat_indices is None:
    print("You must implement Task 1 and 2 first")
else:
    # precision@1 is 1.0 by construction: the top-1 hit defines the ground truth
    assert precision_at_k(flat_indices, 1) == 1.0
    plot_precision_at_k(flat_indices, k)
precision@0: 1.0 precision@1: 0.7 precision@2: 0.6666666666666667
Task 3: Use Hierarchical Small World GraphsΒΆ
To use Hierarchical Small World Graphs in FAISS, we need to work with the HNSW (Hierarchical Navigable Small World) index that FAISS provides. HNSW is a type of small-world graph used for approximate nearest neighbor search, which is highly efficient for high-dimensional data.
index = faiss.IndexHNSWFlat(d, M)
hswg_index.hnsw.efConstruction = ...
hswg_index.hnsw.efSearch = ...
Tuning the HNSW Index:
- M: A higher value will give more accurate results but with increased memory usage and slower indexing times.
- efConstruction: Controls the number of candidate neighbors evaluated during construction of the graph. A larger value can lead to a better-quality graph at the expense of longer construction time.
- efSearch: Controls the number of candidate neighbors considered during search. A larger value can improve accuracy but makes the search slower.
k=3
def index_and_query_hswg(embeddings, query_embeddings, k):
    """Build an HNSW index over ``embeddings`` and query it in one go.

    Returns the index together with the (n_queries, k) array of
    neighbour ids for every query row.
    """
    dim = embeddings.shape[1]
    # 32 links per node: larger is more accurate but slower and bigger.
    hswg_index = faiss.IndexHNSWFlat(dim, 32)
    # Candidate-list sizes at build/search time; larger = better recall.
    hswg_index.hnsw.efConstruction = 40
    hswg_index.hnsw.efSearch = 64
    hswg_index.add(embeddings)
    _, neighbour_ids = hswg_index.search(query_embeddings, k)
    return hswg_index, neighbour_ids
# Time the combined HNSW index construction + query
start_time = time.time()
hswg_index, hswg_indices = index_and_query_hswg(embeddings, query_embeddings, k)
end_time = time.time()
print(f"Time to index and query {np.round(end_time-start_time,3)}s")
Time to index and query 0.509s
# A simple test to test correctness
if hswg_indices is None:
    print("You must implement Task 3 first")
else:
    # precision@1 is 1.0 by construction: the top-1 hit defines the ground truth
    assert precision_at_k(hswg_indices, 1) == 1.0
    plot_precision_at_k(hswg_indices, k)
precision@0: 1.0 precision@1: 0.7 precision@2: 0.6666666666666667
Task 4: Optimize Hyper-Parameters
Test different ranges of the three input parameters
M_values = [8, 16, 32]
efConstruction_values = [40, 100, 200]
efSearch_values = [10, 50, 100]
metric_types = [faiss.METRIC_L2, faiss.METRIC_L1]
For each parameter, create a IndexHNSWFlat, and search for the $5$-NN.
For each set of parameters, record the query-time and the precision@5. Output the best 5 sets of parameters for the best query-times and the best precision@5.
You may measure runtimes using:
start_time = time.time()
...
end_time = time.time()
# Range of parameters to try for tuning
M_values = [8, 16, 32]                 # graph degree (links per node)
efConstruction_values = [40, 100, 200] # candidate-list size at build time
efSearch_values = [10, 50, 100]        # candidate-list size at query time
metric_types = [faiss.METRIC_L2, faiss.METRIC_L1]
k = 5 # Number of neighbors
def hswg_grid_search(embeddings, query_embeddings, k):
    """Grid-search the HNSW hyper-parameters.

    For every combination of the module-level ``M_values``,
    ``efConstruction_values``, ``efSearch_values`` and ``metric_types``,
    build a fresh IndexHNSWFlat, measure the pure query time and the
    precision@k, and rank the results.

    Returns
    -------
    (best_by_query_time, best_by_precision) : tuple of two lists
        The five best (M, efConstruction, efSearch, metric, query_time,
        precision) tuples, sorted by query time / by precision.
    """
    results = []
    dim = embeddings.shape[1]  # hoisted: identical for every combination
    for M in M_values:
        for efConstruction in efConstruction_values:
            for efSearch in efSearch_values:
                for metric in metric_types:
                    hswg_index = faiss.IndexHNSWFlat(dim, M, metric)
                    hswg_index.hnsw.efConstruction = efConstruction
                    hswg_index.hnsw.efSearch = efSearch
                    # A freshly constructed IndexHNSWFlat needs no
                    # training and is empty, so we add unconditionally
                    # (the former reset()/is_trained/ntotal dance was dead code).
                    print("Adding embeddings to index...")
                    hswg_index.add(embeddings)
                    # Time only the search, not the index construction.
                    start_time = time.time()
                    distances, hswg_indices = hswg_index.search(query_embeddings, k)
                    end_time = time.time()
                    query_time = end_time - start_time
                    precision = precision_at_k(hswg_indices, k)
                    results.append((M, efConstruction, efSearch, metric, query_time, precision))
                    print(f"M={M}, efConstruction={efConstruction}, efSearch={efSearch}, metric={metric}: query_time={query_time}, precision@5={precision}")
    # Rank: fastest queries first / highest precision first.
    best_params_query_time = sorted(results, key=lambda x: x[4])[:5]
    best_params_precision_at_5 = sorted(results, key=lambda x: x[5], reverse=True)[:5]
    return best_params_query_time, best_params_precision_at_5
best_params_query_time, best_params_precision_at_5 = hswg_grid_search(embeddings, query_embeddings, k)
# Best Query Time Results
print("\n=== Best Query Time Results ===")
for i, result in enumerate(best_params_query_time, 1):
    print(f"{i}. Time: {result[4]:.4f}s, M={result[0]}, efC={result[1]}, efS={result[2]}, Metric={'L2' if result[3]==faiss.METRIC_L2 else 'L1'}")
# Best Precision Results: print this header too, so the second ranked
# list is not visually fused with the first (its numbering restarts at 1).
print("\n=== Best Precision Results ===")
for i, result in enumerate(best_params_precision_at_5, 1):
    print(f"{i}. Precision: {result[5]:.4f}, M={result[0]}, efC={result[1]}, efS={result[2]}, Metric={'L2' if result[3]==faiss.METRIC_L2 else 'L1'}")
print("Best Precision@5", best_params_precision_at_5)
Adding embeddings to index... M=8, efConstruction=40, efSearch=10, metric=1: query_time=0.0, precision@5=0.4800000000000001 Adding embeddings to index... M=8, efConstruction=40, efSearch=10, metric=2: query_time=0.0, precision@5=0.4600000000000001 Adding embeddings to index... M=8, efConstruction=40, efSearch=50, metric=1: query_time=0.0, precision@5=0.5000000000000001 Adding embeddings to index... M=8, efConstruction=40, efSearch=50, metric=2: query_time=0.0010035037994384766, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=40, efSearch=100, metric=1: query_time=0.0005042552947998047, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=40, efSearch=100, metric=2: query_time=0.0010046958923339844, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=100, efSearch=10, metric=1: query_time=0.0, precision@5=0.5000000000000001 Adding embeddings to index... M=8, efConstruction=100, efSearch=10, metric=2: query_time=0.0, precision@5=0.4800000000000001 Adding embeddings to index... M=8, efConstruction=100, efSearch=50, metric=1: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=100, efSearch=50, metric=2: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=100, efSearch=100, metric=1: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=100, efSearch=100, metric=2: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=200, efSearch=10, metric=1: query_time=0.0005049705505371094, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=200, efSearch=10, metric=2: query_time=0.0, precision@5=0.4800000000000001 Adding embeddings to index... M=8, efConstruction=200, efSearch=50, metric=1: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... 
M=8, efConstruction=200, efSearch=50, metric=2: query_time=0.0010018348693847656, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=200, efSearch=100, metric=1: query_time=0.0010023117065429688, precision@5=0.5200000000000001 Adding embeddings to index... M=8, efConstruction=200, efSearch=100, metric=2: query_time=0.0010023117065429688, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=40, efSearch=10, metric=1: query_time=0.0, precision@5=0.5000000000000001 Adding embeddings to index... M=16, efConstruction=40, efSearch=10, metric=2: query_time=0.0005135536193847656, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=40, efSearch=50, metric=1: query_time=0.0010073184967041016, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=40, efSearch=50, metric=2: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=40, efSearch=100, metric=1: query_time=0.0010004043579101562, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=40, efSearch=100, metric=2: query_time=0.0005035400390625, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=100, efSearch=10, metric=1: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=100, efSearch=10, metric=2: query_time=0.0010035037994384766, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=100, efSearch=50, metric=1: query_time=0.0010008811950683594, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=100, efSearch=50, metric=2: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=100, efSearch=100, metric=1: query_time=0.0010013580322265625, precision@5=0.5200000000000001 Adding embeddings to index... 
M=16, efConstruction=100, efSearch=100, metric=2: query_time=0.0010001659393310547, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=200, efSearch=10, metric=1: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=200, efSearch=10, metric=2: query_time=0.0, precision@5=0.5000000000000001 Adding embeddings to index... M=16, efConstruction=200, efSearch=50, metric=1: query_time=0.0010042190551757812, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=200, efSearch=50, metric=2: query_time=0.0005049705505371094, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=200, efSearch=100, metric=1: query_time=0.0010001659393310547, precision@5=0.5200000000000001 Adding embeddings to index... M=16, efConstruction=200, efSearch=100, metric=2: query_time=0.0010008811950683594, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=40, efSearch=10, metric=1: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=40, efSearch=10, metric=2: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=40, efSearch=50, metric=1: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=40, efSearch=50, metric=2: query_time=0.0010042190551757812, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=40, efSearch=100, metric=1: query_time=0.0010037422180175781, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=40, efSearch=100, metric=2: query_time=0.0010046958923339844, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=100, efSearch=10, metric=1: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... 
M=32, efConstruction=100, efSearch=10, metric=2: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=100, efSearch=50, metric=1: query_time=0.0005071163177490234, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=100, efSearch=50, metric=2: query_time=0.005021810531616211, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=100, efSearch=100, metric=1: query_time=0.001005411148071289, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=100, efSearch=100, metric=2: query_time=0.0005042552947998047, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=200, efSearch=10, metric=1: query_time=0.0, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=200, efSearch=10, metric=2: query_time=0.0009987354278564453, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=200, efSearch=50, metric=1: query_time=0.0020186901092529297, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=200, efSearch=50, metric=2: query_time=0.0005064010620117188, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=200, efSearch=100, metric=1: query_time=0.0010051727294921875, precision@5=0.5200000000000001 Adding embeddings to index... M=32, efConstruction=200, efSearch=100, metric=2: query_time=0.0020074844360351562, precision@5=0.5200000000000001 === Best Query Time Results === 1. Time: 0.0000s, M=8, efC=40, efS=10, Metric=L2 2. Time: 0.0000s, M=8, efC=40, efS=10, Metric=L1 3. Time: 0.0000s, M=8, efC=40, efS=50, Metric=L2 4. Time: 0.0000s, M=8, efC=100, efS=10, Metric=L2 5. Time: 0.0000s, M=8, efC=100, efS=10, Metric=L1 1. Precision: 0.5200, M=8, efC=40, efS=50, Metric=L1 2. Precision: 0.5200, M=8, efC=40, efS=100, Metric=L2 3. Precision: 0.5200, M=8, efC=40, efS=100, Metric=L1 4. 
Precision: 0.5200, M=8, efC=100, efS=50, Metric=L2 5. Precision: 0.5200, M=8, efC=100, efS=50, Metric=L1 Best Precision@5 [(8, 40, 50, 2, 0.0010035037994384766, 0.5200000000000001), (8, 40, 100, 1, 0.0005042552947998047, 0.5200000000000001), (8, 40, 100, 2, 0.0010046958923339844, 0.5200000000000001), (8, 100, 50, 1, 0.0, 0.5200000000000001), (8, 100, 50, 2, 0.0, 0.5200000000000001)]
Task 5: Search for Dogs that look like you.ΒΆ
Use an image of yourself, and find dogs that look like you.
Alternatively: Search the internet and find at least $5$ images of people, who look similar to dogs. For each image,
- Transform the image using CLIP
- Perform a search using any faiss index and plot the results.
def dogs_that_look_like_you(embeddings, k, query_path="queries_task5"):
    """Embed the images in ``query_path`` and plot their k nearest dogs.

    Parameters
    ----------
    embeddings : np.ndarray
        Corpus embeddings to index.
    k : int
        Number of neighbours to retrieve per query image.
    query_path : str, optional
        Folder with the query photos. Defaults to "queries_task5" (the
        previously hard-coded location), so existing calls still work.
    """
    query_embeddings, query_file_names = create_query_embeddings(query_path)
    # Any FAISS index works here; HNSW is fast and accurate enough.
    hswg_index, hswg_indices = index_and_query_hswg(embeddings, query_embeddings, k)
    plot_results(query_file_names, hswg_indices, k)
# Run Task 5 with the default queries folder
dogs_that_look_like_you(embeddings, k)
Task 6: Text-To-Image RetrievalΒΆ
We will now use CLIP text embeddings to find the top-3 matching images.
You are given a set of sentences. Your task is to:
- Generate CLIP embeddings for each sentence
- Search for the Top-3 images using
IndexFlatL2index and MetricMETRIC_L1
Please refer to [1] for additional information on how to generate text-embeddings.
[1] https://github.com/openai/CLIP
You may add additional sentences to do further experiments.
k = 3 # Number of neighbors
# Text queries for the text-to-image retrieval task
sentences = ["a lion",
             "a dog and a girl",
             "a dog and a man",
             "a toy and a dog",
             "a group of people"]
def text_to_image(embeddings, sentences, k):
    """Retrieve the top-k images for each sentence via CLIP text embeddings.

    Parameters
    ----------
    embeddings : np.ndarray of shape (n_images, dim), float32
        The image-embedding corpus to search.
    sentences : list[str]
        Text queries to embed with CLIP.
    k : int
        Number of images to retrieve per sentence.

    Returns
    -------
    np.ndarray of shape (len(sentences), k)
        Row ids of the retrieved images.

    Uses the module-level ``clip``/``model``/``device`` objects.
    """
    # CLIP text encoding; inference only, so no gradients are needed.
    text_tokens = clip.tokenize(sentences).to(device)
    with torch.no_grad():
        text_embeddings = model.encode_text(text_tokens)
    # FAISS requires float32 inputs.
    text_embeddings = text_embeddings.cpu().numpy().astype('float32')
    # Fix: IndexFlat supports METRIC_L1 natively, so search under the
    # true L1 (Manhattan) metric. The previous sign(x)*sqrt(|x|)
    # transform under IndexFlatL2 does NOT reproduce L1 distances
    # (||f(x)-f(y)||^2 = |x|+|y|-2*sign(x)sign(y)*sqrt(|x||y|) != |x-y|).
    index_flat = faiss.IndexFlat(embeddings.shape[1], faiss.METRIC_L1)
    index_flat.add(embeddings)
    distances, indices = index_flat.search(text_embeddings, k)
    return indices
indices = text_to_image(embeddings, sentences, k)
# Bug fix: the guard and the sanity check must inspect `indices` from
# Task 6, not `flat_indices` left over from Tasks 1/2.
if indices is None:
    print("You must implement Task 6")
else:
    assert precision_at_k(indices, 1) == 1.0
    plot_results(sentences, indices, k)